import torch
import numpy as np
import h5py

class SimpleHDF5Dataset:
  """Indexable collection of (feature tensor, integer label) pairs.

  When a file handle is supplied, the ``'all_feats'`` and ``'all_labels'``
  entries are read fully into memory via ``[...]`` and the dataset length is
  taken from the first element of the ``'count'`` entry. Without a handle the
  dataset is empty (length 0).
  """

  def __init__(self, file_handle = None):
    """Load features, labels and the sample count from ``file_handle``.

    Args:
      file_handle: mapping-like object (presumably an open ``h5py.File`` --
        confirm against callers) providing ``'all_feats'``, ``'all_labels'``
        and ``'count'``; ``None`` creates an empty dataset.
    """
    # Identity check: a handle type may override ``__eq__``, so ``== None``
    # is not a reliable way to detect "no handle given".
    if file_handle is None:
      self.f = ''
      self.all_feats_dset = []
      self.all_labels = []
      self.total = 0
      return

    self.f = file_handle
    # ``[...]`` materialises the data so it stays usable after the handle
    # is closed.
    self.all_feats_dset = self.f['all_feats'][...]
    self.all_labels = self.f['all_labels'][...]
    # Coerce to a built-in int: ``__len__`` must return an integer, and a
    # count stored with a floating dtype would otherwise make ``len()`` fail.
    self.total = int(self.f['count'][0])

  def __getitem__(self, i):
    """Return ``(torch.Tensor of row i, int label of row i)``."""
    return torch.Tensor(self.all_feats_dset[i,:]), int(self.all_labels[i])

  def __len__(self):
    """Return the number of samples recorded in the ``'count'`` entry."""
    return self.total


def init_loader(filename):
  """Read an HDF5 feature file and group its feature rows by label.

  Trailing rows whose elements sum to zero are discarded before grouping
  (presumably unused padding at the end of the stored arrays -- confirm with
  the code that writes these files).

  Args:
    filename: path of the HDF5 file to open read-only.

  Returns:
    dict mapping each distinct label to the list of feature rows carrying
    that label, in file order. Empty when no rows remain after trimming.
  """
  with h5py.File(filename, 'r') as f:
    fileset = SimpleHDF5Dataset(f)

  # The dataset copied everything into memory, so the arrays remain valid
  # after the file has been closed.
  feats = fileset.all_feats_dset
  labels = fileset.all_labels

  # Count how many leading rows to keep, then slice once. This avoids
  # copying the whole array per removed row and, thanks to the bound check,
  # cannot index into an empty array when every row sums to zero.
  keep = len(feats)
  while keep > 0 and np.sum(feats[keep - 1]) == 0:
    keep -= 1
  feats = feats[:keep]
  labels = labels[:keep]

  # Pre-create one bucket per distinct label, keyed by plain Python scalars.
  cl_data_file = {cl: [] for cl in np.unique(np.array(labels)).tolist()}
  for label, feat in zip(labels, feats):
    cl_data_file[label].append(feat)

  return cl_data_file
